自定义回调函数和学习率调度器¶
Note
回调函数和学习率调度器也是可以自定义的,本节我们举例说明。
回调函数¶
from tensorflow import keras
class PrintValTrainRatioCallback(keras.callbacks.Callback):
# 在每个epoch后打印val loss和train loss的比值
# 若想batch-wise操作,自定义on_batch_end()即可
def on_epoch_end(self, epoch, logs):
print("\nval/train: {:.2f}".format(logs["val_loss"] / logs["loss"]))
学习率调度器¶
def exponential_decay(lr0, s):
# 指数学习率
def exponential_decay_fn(epoch):
return lr0 * 0.1 ** (epoch / s)
return exponential_decay_fn
# 第一步:实现一个以epoch为参数的学习率函数
exponential_decay_fn = exponential_decay(lr0=0.01, s=20)
# 第二步:将函数传给keras.callbacks.LearningRateScheduler
# 即可像其他回调函数那样在fit时使用
lr_scheduler = keras.callbacks.LearningRateScheduler(exponential_decay_fn)